// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

#include "fused_moegemm.hpp"
#include "fused_moesorting.hpp"

// Runtime arguments for fused_moe(): buffer pointers plus problem sizes.
// In the shape comments below, k is hidden_size and e is the expert count;
// m presumably equals num_tokens — TODO confirm.
// Every member has a default initializer (nullptr / 0), so a
// default-constructed instance never carries indeterminate values. Positional
// or designated aggregate initialization still works exactly as before.
struct fused_moe_args {
  // ---- tokens, weights, scales, output ----
  const void* a_ptr = nullptr; // [m, k], input token
  const void* a_scale_ptr = nullptr; // [m, 1], token scale
  const void* g_ptr = nullptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
  const void* d_ptr = nullptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
  const void* g_scale_ptr = nullptr; // [e, 1, n], gate(up) scale
  const void* d_scale_ptr = nullptr; // [e, 1, k], down scale
  const void* y_smooth_scale_ptr =
      nullptr; // [e, 1, n], smooth-quant scale for the 2nd gemm input
  const void* local_expert_mask_ptr = nullptr; // [e], local expert mask for EP
  void* o_ptr = nullptr; // [m, k], output token (caller need not zero it)

  // ---- routing / sorting buffers ----
  const void* topk_ids_ptr = nullptr; // [tokens, topk]
  const void* topk_weight_ptr = nullptr; // [tokens, topk]
  void* sorted_token_ids_ptr = nullptr; // [max_num_tokens_padded]
  void* sorted_weight_ptr = nullptr; // [max_num_tokens_padded]
  void* sorted_expert_ids_ptr =
      nullptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
  void* num_sorted_tiles_ptr = nullptr; // [1]

  // ---- problem sizes ----
  ck_tile::index_t block_m = 0; // block size along m, used to divide the input
  ck_tile::index_t hidden_size = 0; // k
  ck_tile::index_t intermediate_size =
      0; // n / TP; the same value applies to gate, up and down
  ck_tile::index_t num_tokens = 0; // input number of tokens for current iteration
  ck_tile::index_t num_experts = 0; // number of experts (groups), e
  ck_tile::index_t topk =
      0; // topk dimension of topk_ids/topk_weight. NOTE(review): confirm
         // whether the kernels actually read this field
  ck_tile::index_t stride_token = 0; // row stride for input/output tokens;
                                     // should be >= hidden_size
};

// This is the public API; it will be generated by a script.
struct fused_moe_traits {
  // Data types, given as names in strings. The accepted spellings are defined
  // by the implementation of fused_moe(), not in this header.
  std::string prec_i; // input precision
  std::string prec_w; // weight precision
  std::string prec_o; // output precision
  std::string prec_st; // token scale data type
  std::string prec_sw; // weight scale data type
  std::string prec_sq; // smooth quant scale data type
  std::string prec_kw; // topk-weight data type

  // The scalar options below default to 0/false, so a default-initialized
  // traits object is fully determinate (same values as `fused_moe_traits{}`).
  int block_m = 0; // presumably must match fused_moe_args::block_m — confirm
  int activation = 0; // 0:gelu, 1:silu
  int gate_only = 0; // 1: gate only (g1u0), 0: gate + up (g1u1).
                     // NOTE(review): mapping follows the field name; confirm
                     // against the dispatcher implementation.
  int fused_quant = 0; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant

  bool local_expert_masking = false; // if mask experts as local expert
};

// Public fused-MoE entry point.
//   traits : data types and kernel options (taken by value)
//   args   : buffer pointers and problem sizes (taken by value)
//   stream : ck_tile stream configuration used for the launch
// Returns a float. Presumably this is the measured kernel time when timing is
// enabled in `stream`, per the ck_tile convention — confirm against the
// generated implementation.
float fused_moe(fused_moe_traits traits,
                fused_moe_args args,
                const ck_tile::stream_config& stream);
